package com.ken.ability.gateway.filter.limit;

import cn.hutool.extra.spring.SpringUtil;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;

import java.util.Collections;
import java.util.concurrent.TimeUnit;

/**
 * 令牌桶的类 - TokenTong的对象就相当于一个令牌桶
 */
public class TokenTong {

    //令牌桶的key
    private String key;
    //最大令牌数
    private Integer maxTokens;
    //拥有的令牌数
    private Integer hasTokens;
    //令牌的生成速率 n/s - 每秒生成多少令牌
    private Integer secTokens;

    private StringRedisTemplate redisTemplate;

    //初始化令牌桶的lua脚本
    private String initTokenLua = "--接收令牌桶参数\n" +
            "--令牌桶的key \n" +
            "local key = \"tokenlimiter_\"..KEYS[1]\n" +
            "\n" +
            "--判断令牌桶是否存在\n" +
            "local flag = redis.call(\"exists\", key)\n" +
            "if flag == 0 then\n" +
            "    --拥有的令牌数\n" +
            "    local hasTokens = ARGV[1]\n" +
            "    --最大令牌数\n" +
            "    local maxTokens = ARGV[2]\n" +
            "    --每秒生成的令牌数\n" +
            "    local secTokens = ARGV[3]\n" +
            "\n" +
            "    --获取当前时间\n" +
            "    local nowTime = redis.call('time')\n" +
            "    --获取当前时间的微秒值\n" +
            "    local nowks = nowTime[1] * 1000000 + nowTime[2]\n" +
            "\n" +
            "    --令牌桶不存在，创建令牌桶\n" +
            "    redis.call(\"hmset\", key, \"hasTokens\", hasTokens, \"maxTokens\", maxTokens, \"secTokens\", secTokens, \"nextTime\", nowks)\n" +
            "    return 1\n" +
            "end\n" +
            "\n" +
            "return 0\n";

    //获取令牌的lua脚本
    private String getTokenLua = "--获取令牌的脚本\n" +
            "--令牌桶的key \n" +
            "local key = \"tokenlimiter_\"..KEYS[1]\n" +
            "--令牌桶拥有的令牌数\n" +
            "local hasTokens = tonumber(redis.call(\"hget\", key, \"hasTokens\"))\n" +
            "--令牌桶最大令牌数\n" +
            "local maxTokens= tonumber(redis.call(\"hget\", key, \"maxTokens\"))\n" +
            "--每秒生成令牌数\n" +
            "local secTokens = tonumber(redis.call(\"hget\", key, \"secTokens\"))\n" +
            "--最后生成令牌的时间\n" +
            "local nextTime = tonumber(redis.call(\"hget\", key, \"nextTime\"))\n" +
            "\n" +
            "--每个令牌生成的耗时 - 微秒\n" +
            "local singleTokenTime = 1000000 / secTokens\n" +
            "\n" +
            "--当前的时间\n" +
            "local now = redis.call('time')\n" +
            "local nowTime = now[1] * 1000000 + now[2]\n" +
            "\n" +
            "--计算令牌的生成数量\n" +
            "if nowTime > nextTime then\n" +
            "   --生成令牌\n" +
            "   --经过的时间\n" +
            "   local times = nowTime - nextTime\n" +
            "   --计算生成了多少块令牌\n" +
            "   local createTokens = times /  singleTokenTime\n" +
            "   if createTokens > 0 then\n" +
            "       --拥有的令牌\n" +
            "       hasTokens = math.min(hasTokens + createTokens, maxTokens)\n" +
            "       --下一次能计算令牌的时间\n" +
            "       nextTime = nowTime\n" +
            "   end\n" +
            "end\n" +
            "\n" +
            "--开始领取令牌\n" +
            "local getTokens = tonumber(ARGV[1])\n" +
            "--是否需要立刻获得令牌 - 不预支\n" +
            "local flag = tonumber(ARGV[2] or 0)\n" +
            "--能拿走的令牌数\n" +
            "local canGetTokens = math.min(getTokens, hasTokens)\n" +
            "--需要预支的令牌数\n" +
            "local yuzhiTokens = getTokens - canGetTokens\n" +
            "--生成这些令牌需要的时间\n" +
            "local waitTime = singleTokenTime * yuzhiTokens\n" +
            "\n" +
            "--判断是否需要预支或者等待\n" +
            "if flag == 1 and nextTime + waitTime - nowTime > 0 then\n" +
            "    return -1\n" +
            "end\n" +
            " \n" +
            "--更新令牌桶的数据\n" +
            "redis.call(\"hmset\", key, \"hasTokens\", hasTokens - canGetTokens, \"nextTime\", nextTime + waitTime)\n" +
            "--返回结果\n" +
            "return nextTime + waitTime - nowTime";

    /**
     * 构造方法，生成令牌桶
     * @param maxTokens
     * @param secTokens
     */
    public TokenTong(String key, Integer maxTokens, Integer secTokens) {
        this.key = key;
        this.maxTokens = maxTokens;
        this.hasTokens = maxTokens;
        this.secTokens = secTokens;
        this.redisTemplate = SpringUtil.getBean(StringRedisTemplate.class);
        //初始化令牌桶
        initToken();
    }

    /**
     * 初始化令牌桶
     */
    private void initToken(){
        redisTemplate.execute(new DefaultRedisScript<>(initTokenLua, Long.class),
                Collections.singletonList(key),
                hasTokens + "", maxTokens + "", secTokens + "");
    }

    /**
     * 获取令牌
     * 参数为：需要的令牌数量
     * 返回值：需要等待的时间
     * @return
     */
    public long getToken(Integer tokens){
        //需要等待的时间
        long waitTime = redisTemplate.execute(new DefaultRedisScript<Long>(getTokenLua, Long.class),
                Collections.singletonList(key),
                tokens + "");

        if (waitTime > 0) {
            try {
                Thread.sleep(TimeUnit.MICROSECONDS.toMillis(waitTime));
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        }

        return waitTime;
    }

    /**
     * 当前能否立刻获得令牌 - 不会进行令牌预支
     * @return
     */
    public boolean getTokenNow(Integer tokens){
        //需要等待的时间
        long waitTime = redisTemplate.execute(new DefaultRedisScript<Long>(getTokenLua, Long.class),
                Collections.singletonList(key),
                tokens + "", "1");

        return waitTime != -1 ? true : false;
    }

}
